/*

 * Licensed to the Apache Software Foundation (ASF) under one

 * or more contributor license agreements.  See the NOTICE file

 * distributed with this work for additional information

 * regarding copyright ownership.  The ASF licenses this file

 * to you under the Apache License, Version 2.0 (the

 * "License"); you may not use this file except in compliance

 * with the License.  You may obtain a copy of the License at

 *

 *     http://www.apache.org/licenses/LICENSE-2.0

 *

 * Unless required by applicable law or agreed to in writing, software

 * distributed under the License is distributed on an "AS IS" BASIS,

 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 * See the License for the specific language governing permissions and

 * limitations under the License.

 */

package com.bff.gaia.unified.sdk.transforms;



import com.bff.gaia.unified.sdk.coders.Coder;

import com.bff.gaia.unified.sdk.coders.KvCoder;

import com.bff.gaia.unified.sdk.state.*;

import com.bff.gaia.unified.sdk.transforms.windowing.BoundedWindow;

import com.bff.gaia.unified.sdk.values.KV;

import com.bff.gaia.unified.sdk.values.PCollection;

import com.bff.gaia.unified.vendor.guava.com.google.common.annotations.VisibleForTesting;

import com.bff.gaia.unified.vendor.guava.com.google.common.collect.Iterables;

import com.bff.gaia.unified.sdk.state.BagState;

import com.bff.gaia.unified.sdk.state.CombiningState;

import com.bff.gaia.unified.sdk.state.StateSpec;

import com.bff.gaia.unified.sdk.state.StateSpecs;

import com.bff.gaia.unified.sdk.state.TimeDomain;

import com.bff.gaia.unified.sdk.state.Timer;

import com.bff.gaia.unified.sdk.state.TimerSpec;

import com.bff.gaia.unified.sdk.state.TimerSpecs;

import com.bff.gaia.unified.sdk.state.ValueState;

import org.joda.time.Duration;

import org.joda.time.Instant;

import org.slf4j.Logger;

import org.slf4j.LoggerFactory;



import static com.bff.gaia.unified.vendor.guava.com.google.common.base.Preconditions.checkArgument;



/**

 * A {@link PTransform} that batches inputs to a desired batch size. Batches will contain only

 * elements of a single key.

 *

 * <p>Elements are buffered until there are {@code batchSize} elements buffered, at which point they

 * are output to the output {@link PCollection}.

 *

 * <p>Windows are preserved (batches contain elements from the same window). Batches may contain

 * elements from more than one bundle.

 *

 * <p>Example (batch call a webservice and get return codes):

 *

 * <pre>{@code

 * PCollection<KV<String, String>> input = ...;

 * long batchSize = 100L;

 * PCollection<KV<String, Iterable<String>>> batched = input

 *     .apply(GroupIntoBatches.<String, String>ofSize(batchSize))

 *     .setCoder(KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(StringUtf8Coder.of())))

 *     .apply(ParDo.of(new DoFn<KV<String, Iterable<String>>, KV<String, String>>() }{

 *        {@code @ProcessElement

 *         public void processElement(@Element KV<String, Iterable<String>> element,

 *             OutputReceiver<KV<String, String>> r) {

 *             r.output(KV.of(element.getKey(), callWebService(element.getValue())));

 *         }

 *     }}));

 * </pre>

 */

public class GroupIntoBatches<K, InputT>

    extends PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, Iterable<InputT>>>> {



  private final long batchSize;



  private GroupIntoBatches(long batchSize) {

    this.batchSize = batchSize;

  }



  public static <K, InputT> GroupIntoBatches<K, InputT> ofSize(long batchSize) {

    return new GroupIntoBatches<>(batchSize);

  }



  @Override

  public PCollection<KV<K, Iterable<InputT>>> expand(PCollection<KV<K, InputT>> input) {

    Duration allowedLateness = input.getWindowingStrategy().getAllowedLateness();



    checkArgument(

        input.getCoder() instanceof KvCoder,

        "coder specified in the input PCollection is not a KvCoder");

    KvCoder inputCoder = (KvCoder) input.getCoder();

    Coder<K> keyCoder = (Coder<K>) inputCoder.getCoderArguments().get(0);

    Coder<InputT> valueCoder = (Coder<InputT>) inputCoder.getCoderArguments().get(1);



    return input.apply(

        ParDo.of(new GroupIntoBatchesDoFn<>(batchSize, allowedLateness, keyCoder, valueCoder)));

  }



  @VisibleForTesting

  static class GroupIntoBatchesDoFn<K, InputT>

      extends DoFn<KV<K, InputT>, KV<K, Iterable<InputT>>> {



    private static final Logger LOG = LoggerFactory.getLogger(GroupIntoBatchesDoFn.class);

    private static final String END_OF_WINDOW_ID = "endOFWindow";

    private static final String BATCH_ID = "batch";

    private static final String NUM_ELEMENTS_IN_BATCH_ID = "numElementsInBatch";

    private static final String KEY_ID = "key";

    private final long batchSize;

    private final Duration allowedLateness;



    @TimerId(END_OF_WINDOW_ID)

    private final TimerSpec timer = TimerSpecs.timer(TimeDomain.EVENT_TIME);



    @StateId(BATCH_ID)

    private final StateSpec<BagState<InputT>> batchSpec;



    @StateId(NUM_ELEMENTS_IN_BATCH_ID)

    private final StateSpec<CombiningState<Long, long[], Long>> numElementsInBatchSpec;



    @StateId(KEY_ID)

    private final StateSpec<ValueState<K>> keySpec;



    private final long prefetchFrequency;



    GroupIntoBatchesDoFn(

        long batchSize,

        Duration allowedLateness,

        Coder<K> inputKeyCoder,

        Coder<InputT> inputValueCoder) {

      this.batchSize = batchSize;

      this.allowedLateness = allowedLateness;

      this.batchSpec = StateSpecs.bag(inputValueCoder);

      this.numElementsInBatchSpec =

          StateSpecs.combining(

              new Combine.BinaryCombineLongFn() {



                @Override

                public long identity() {

                  return 0L;

                }



                @Override

                public long apply(long left, long right) {

                  return left + right;

                }

              });



      this.keySpec = StateSpecs.value(inputKeyCoder);

      // prefetch every 20% of batchSize elements. Do not prefetch if batchSize is too little

      this.prefetchFrequency = ((batchSize / 5) <= 1) ? Long.MAX_VALUE : (batchSize / 5);

    }



    @ProcessElement

    public void processElement(

        @TimerId(END_OF_WINDOW_ID) Timer timer,

        @StateId(BATCH_ID) BagState<InputT> batch,

        @StateId(NUM_ELEMENTS_IN_BATCH_ID) CombiningState<Long, long[], Long> numElementsInBatch,

        @StateId(KEY_ID) ValueState<K> key,

        @Element KV<K, InputT> element,

        BoundedWindow window,

        OutputReceiver<KV<K, Iterable<InputT>>> receiver) {

      Instant windowExpires = window.maxTimestamp().plus(allowedLateness);



      LOG.debug(

          "*** SET TIMER *** to point in time {} for window {}",

          windowExpires.toString(),

          window.toString());

      timer.set(windowExpires);

      key.write(element.getKey());

      batch.add(element.getValue());

      LOG.debug("*** BATCH *** Add element for window {} ", window.toString());

      // blind add is supported with combiningState

      numElementsInBatch.add(1L);

      Long num = numElementsInBatch.read();

      if (num % prefetchFrequency == 0) {

        // prefetch data and modify batch state (readLater() modifies this)

        batch.readLater();

      }

      if (num >= batchSize) {

        LOG.debug("*** END OF BATCH *** for window {}", window.toString());

        flushBatch(receiver, key, batch, numElementsInBatch);

      }

    }



    @OnTimer(END_OF_WINDOW_ID)

    public void onTimerCallback(

        OutputReceiver<KV<K, Iterable<InputT>>> receiver,

        @Timestamp Instant timestamp,

        @StateId(KEY_ID) ValueState<K> key,

        @StateId(BATCH_ID) BagState<InputT> batch,

        @StateId(NUM_ELEMENTS_IN_BATCH_ID) CombiningState<Long, long[], Long> numElementsInBatch,

        BoundedWindow window) {

      LOG.debug(

          "*** END OF WINDOW *** for timer timestamp {} in windows {}",

          timestamp,

          window.toString());

      flushBatch(receiver, key, batch, numElementsInBatch);

    }



    private void flushBatch(

        OutputReceiver<KV<K, Iterable<InputT>>> receiver,

        ValueState<K> key,

        BagState<InputT> batch,

        CombiningState<Long, long[], Long> numElementsInBatch) {

      Iterable<InputT> values = batch.read();

      // when the timer fires, batch state might be empty

      if (!Iterables.isEmpty(values)) {

        receiver.output(KV.of(key.read(), values));

      }

      batch.clear();

      LOG.debug("*** BATCH *** clear");

      numElementsInBatch.clear();

    }

  }

}